Synphony
Deep Learning Final Project - MSDS Spring Module 2 - 2025
Aditi Puttur & Emma Juan
1. Data Preprocessing¶
In [2]:
import pandas as pd
import numpy as np
import os
import json
from tqdm import tqdm
import re
import unicodedata
import warnings
warnings.filterwarnings("ignore")
from miditok import REMI, TokenizerConfig, TokSequence
from miditoolkit import MidiFile
from symusic import Score
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import ReduceLROnPlateau
import math
from typing import Optional
import traceback
Loading the data¶
LMD: Midi Files¶
In [3]:
# Open and read the JSON file
with open('data/LMD/md5_to_paths.json', 'r') as file:
md5_to_paths = json.load(file)
In [7]:
md5_to_paths['1c83fc02b8c57fbc2605900bb31793fb']
Out[7]:
['E/Exaltasamba - Megastar.mid', 'Midis Samba e Pagode/Exaltasamba - Megastar.mid', 'Midis Samba e Pagode/Exaltasamba - Megastar.mid']
In [9]:
# Walk the matched-LMD tree and collect the full path of every .mid file.
lmd_catalog = []
for root, _dirs, files in os.walk('data/LMD/lmd_matched'):
    for name in files:
        candidate = os.path.join(root, name)
        if candidate.endswith('.mid'):
            lmd_catalog.append(candidate)
In [10]:
lmd_catalog.sort()
lmd_catalog[:10]
Out[10]:
['data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/1d9d16a9da90c090809c153754823c2b.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/5dd29e99ed7bd3cc0c5177a6e9de22ea.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/b97c529ab9ef783a849b896816001748.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/dac3cdd0db6341d8dc14641e44ed0d44.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/05f21994c71a5f881e64f45c8d706165.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/10288ea8e07b70c17f872fda82b94330.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/6304d2bba4282f8bd74322828c30f0c7.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/c24989559d170135b9c6546d1d2df20b.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/ddb6a3db65461dca1a43de72f5375d8b.mid', 'data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/dfea6fd75926c571a87db789280d059d.mid']
In [6]:
len(lmd_catalog)
Out[6]:
116189
In [7]:
# Catalog table: full path, MSD track id (parent directory name),
# and MIDI md5 id (file name without the .mid extension).
lmd_catalog_all = {
    'path': lmd_catalog,
    'MSD_name': [p.split('/')[-2] for p in lmd_catalog],
    'LMD_name': [p.split('/')[-1].split('.')[-2] for p in lmd_catalog],
}
lmd_df = pd.DataFrame(lmd_catalog_all)
lmd_df
Out[7]:
| path | MSD_name | LMD_name | |
|---|---|---|---|
| 0 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | TRAAAGR128F425B14B | 1d9d16a9da90c090809c153754823c2b |
| 1 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | TRAAAGR128F425B14B | 5dd29e99ed7bd3cc0c5177a6e9de22ea |
| 2 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | TRAAAGR128F425B14B | b97c529ab9ef783a849b896816001748 |
| 3 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | TRAAAGR128F425B14B | dac3cdd0db6341d8dc14641e44ed0d44 |
| 4 | data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/... | TRAAAZF12903CCCF6B | 05f21994c71a5f881e64f45c8d706165 |
| ... | ... | ... | ... |
| 116184 | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | TRZZZTN128EF35C42F | 165e156e5192569e41dc8390b80a1465 |
| 116185 | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | TRZZZTN128EF35C42F | 87e403b5fcb06718767aee0a9386f86c |
| 116186 | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | TRZZZTN128EF35C42F | c56e00ecc890dfdfbdd551cb9ea15ca5 |
| 116187 | data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... | TRZZZYV128F92E996D | 1b966417a9aa703873c5fa1cfe18da32 |
| 116188 | data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... | TRZZZYV128F92E996D | 3bcd7e0cc20adcc8dc3e912623bb0e1b |
116189 rows × 3 columns
In [8]:
lmd_df["MSD_name"].nunique()
Out[8]:
31034
LMD-matched metadata (MillionSongDataset): The Metadata¶
In [11]:
import hdf5_getters
In [12]:
# Walk the LMD-matched MSD dump and pull per-track metadata out of each .h5
# file via the MSD-provided hdf5_getters helpers.
# NOTE(review): tqdm wraps os.walk, so the bar counts directories, not files,
# and has no total. Also, the h5 handles are never closed — consider h5.close().
msd_catalog = []
titles = []
artists = []
releases = []
years = []
for dirpath, dirnames, filenames in tqdm(os.walk('data/LMD-matched-MSD')):
    for file in filenames:
        full_path = os.path.join(dirpath, file)
        if full_path.endswith('.h5'):
            # Append the path to the list
            msd_catalog.append(full_path)
            # Get the metadata (titles/artists come back as bytes; decoded later)
            h5 = hdf5_getters.open_h5_file_read(full_path)
            titles.append(hdf5_getters.get_title(h5))
            artists.append(hdf5_getters.get_artist_name(h5))
            releases.append(hdf5_getters.get_release(h5))
            years.append(hdf5_getters.get_year(h5))
            # danceability = hdf5_getters.get_danceability(h5)
            # get_energy = hdf5_getters.get_energy(h5)
15298it [07:23, 34.52it/s]
In [13]:
msd_catalog[:10]
Out[13]:
['data/LMD-matched-MSD/R/R/U/TRRRUFD12903CD7092.h5', 'data/LMD-matched-MSD/R/R/U/TRRRUTV12903CEA11B.h5', 'data/LMD-matched-MSD/R/R/U/TRRRUJO128E07813E7.h5', 'data/LMD-matched-MSD/R/R/I/TRRRIYO128F428CF6F.h5', 'data/LMD-matched-MSD/R/R/I/TRRRILO128F422FFED.h5', 'data/LMD-matched-MSD/R/R/I/TRRRIVC12903CA6C5A.h5', 'data/LMD-matched-MSD/R/R/I/TRRRILD128F92CB682.h5', 'data/LMD-matched-MSD/R/R/I/TRRRION128F145EBB7.h5', 'data/LMD-matched-MSD/R/R/N/TRRRNPV128F42AAA55.h5', 'data/LMD-matched-MSD/R/R/N/TRRRNGS12903CD16D9.h5']
In [12]:
len(msd_catalog)
Out[12]:
31034
In [13]:
len(msd_catalog) == lmd_df["MSD_name"].nunique()
Out[13]:
True
In [14]:
titles[:5]
Out[14]:
[b'Wastelands', b'Runaway', b'Have You Met Miss Jones? (Swing When Version)', b'Goodbye', b'La Colegiala']
In [15]:
artists[:5]
Out[15]:
[b'Hawkwind', b'Del Shannon', b'Robbie Williams', b'Volebeats', b'Rodolfo Y Su Tipica Ra7']
In [16]:
years[:5]
Out[16]:
[1994, 1961, 2001, 0, 1997]
In [17]:
# Decode the raw bytes returned by hdf5_getters into Python strings.
# NOTE(review): this cell is not idempotent — running it a second time raises
# AttributeError because str has no .decode.
titles = [title.decode('utf-8') for title in titles]
artists = [artist.decode('utf-8') for artist in artists]
In [18]:
# Assemble the MSD metadata into a DataFrame; MSD_name is the .h5 file stem.
msd_catalog_all = {
    'path': msd_catalog,
    'MSD_name': [p.split('/')[-1].split('.')[-2] for p in msd_catalog],
    'title': titles,
    'artist': artists,
    'year': years,
}
msd_df = pd.DataFrame(msd_catalog_all)
msd_df
Out[18]:
| path | MSD_name | title | artist | year | |
|---|---|---|---|---|---|
| 0 | data/LMD-matched-MSD/R/R/U/TRRRUFD12903CD7092.h5 | TRRRUFD12903CD7092 | Wastelands | Hawkwind | 1994 |
| 1 | data/LMD-matched-MSD/R/R/U/TRRRUTV12903CEA11B.h5 | TRRRUTV12903CEA11B | Runaway | Del Shannon | 1961 |
| 2 | data/LMD-matched-MSD/R/R/U/TRRRUJO128E07813E7.h5 | TRRRUJO128E07813E7 | Have You Met Miss Jones? (Swing When Version) | Robbie Williams | 2001 |
| 3 | data/LMD-matched-MSD/R/R/I/TRRRIYO128F428CF6F.h5 | TRRRIYO128F428CF6F | Goodbye | Volebeats | 0 |
| 4 | data/LMD-matched-MSD/R/R/I/TRRRILO128F422FFED.h5 | TRRRILO128F422FFED | La Colegiala | Rodolfo Y Su Tipica Ra7 | 1997 |
| ... | ... | ... | ... | ... | ... |
| 31029 | data/LMD-matched-MSD/W/W/Y/TRWWYHD12903CC42B1.h5 | TRWWYHD12903CC42B1 | Gethsemane (I Only Want to Say) (Live-LP Version) | Michael Crawford | 0 |
| 31030 | data/LMD-matched-MSD/W/W/Y/TRWWYNJ128F426541F.h5 | TRWWYNJ128F426541F | Cold Feelings | Social Distortion | 1992 |
| 31031 | data/LMD-matched-MSD/W/W/P/TRWWPSV128F4244C71.h5 | TRWWPSV128F4244C71 | Ases Death | At Vance | 2001 |
| 31032 | data/LMD-matched-MSD/W/W/P/TRWWPBK128F42911E9.h5 | TRWWPBK128F42911E9 | Drowned Maid | Amorphis | 1993 |
| 31033 | data/LMD-matched-MSD/W/W/W/TRWWWUT128F9364D1A.h5 | TRWWWUT128F9364D1A | Ting-A-Ling | A Balladeer | 0 |
31034 rows × 5 columns
In [19]:
msd_df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 31034 entries, 0 to 31033 Data columns (total 5 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 path 31034 non-null object 1 MSD_name 31034 non-null object 2 title 31034 non-null object 3 artist 31034 non-null object 4 year 31034 non-null int32 dtypes: int32(1), object(4) memory usage: 1.1+ MB
tagtraum: Adding Genre Tags¶
In [20]:
# Parse the tagtraum genre annotations: one "<track_id>\t<genre>" pair per
# line; lines starting with '#' are header/comment lines.
tagtraum = {'MSD_name': [], 'genre': []}
with open("data/tagtraum/msd_tagtraum_cd2c.cls", "r") as file:
    for line in file:
        if line.startswith('#'):
            continue
        track, genre = line.strip().split('\t')
        tagtraum['MSD_name'].append(track)
        tagtraum['genre'].append(genre)
In [21]:
tagtraum_df = pd.DataFrame(tagtraum)
tagtraum_df
Out[21]:
| MSD_name | genre | |
|---|---|---|
| 0 | TRAAAAK128F9318786 | Rock |
| 1 | TRAAAAW128F429D538 | Rap |
| 2 | TRAAADJ128F4287B47 | Rock |
| 3 | TRAAADZ128F9348C2E | Latin |
| 4 | TRAAAED128E0783FAB | Jazz |
| ... | ... | ... |
| 191396 | TRZZZMY128F426D7A2 | Reggae |
| 191397 | TRZZZRJ128F42819AF | Rock |
| 191398 | TRZZZUK128F92E3C60 | Folk |
| 191399 | TRZZZZD128F4236844 | Rock |
| 191400 | TRZZZZZ12903D05E3A | Electronic |
191401 rows × 2 columns
In [22]:
tagtraum_df["genre"].unique()
Out[22]:
array(['Rock', 'Rap', 'Latin', 'Jazz', 'Electronic', 'Pop', 'Metal',
'RnB', 'Country', 'Reggae', 'Blues', 'Folk', 'Punk', 'World',
'New Age'], dtype=object)
Creating our dataset: MIDI + Metadata + Genres¶
Midi + Metadata¶
Each track (MSD_name -> track_id) has exactly one metadata file, and one or more MIDI files (LMD_name -> midi_id) associated with it.
In [23]:
len(lmd_df), len(msd_df)
Out[23]:
(116189, 31034)
In [24]:
lmd_df["MSD_name"].nunique(), len(msd_df)
Out[24]:
(31034, 31034)
In [25]:
# Attach MSD metadata to every MIDI file (inner join on the MSD track id),
# then keep only the columns the rest of the notebook needs.
dataset = lmd_df.merge(msd_df, how="inner", on="MSD_name", suffixes=('_lmd', '_msd'))
dataset = dataset.rename(columns={
    "path_lmd": "midi_filepath",
    "path_msd": "metadata_filepath",
    "MSD_name": "track_id",
    "LMD_name": "midi_id",
})
dataset = dataset[["track_id", "midi_id", "midi_filepath", "title", "artist", "year"]]
dataset
Out[25]:
| track_id | midi_id | midi_filepath | title | artist | year | |
|---|---|---|---|---|---|---|
| 0 | TRAAAGR128F425B14B | 1d9d16a9da90c090809c153754823c2b | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 |
| 1 | TRAAAGR128F425B14B | 5dd29e99ed7bd3cc0c5177a6e9de22ea | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 |
| 2 | TRAAAGR128F425B14B | b97c529ab9ef783a849b896816001748 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 |
| 3 | TRAAAGR128F425B14B | dac3cdd0db6341d8dc14641e44ed0d44 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 |
| 4 | TRAAAZF12903CCCF6B | 05f21994c71a5f881e64f45c8d706165 | data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/... | Break My Stride | Matthew Wilder | 1983 |
| ... | ... | ... | ... | ... | ... | ... |
| 116184 | TRZZZTN128EF35C42F | 165e156e5192569e41dc8390b80a1465 | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | Funky Dance Music Vol 1 | DJ Rob E | 0 |
| 116185 | TRZZZTN128EF35C42F | 87e403b5fcb06718767aee0a9386f86c | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | Funky Dance Music Vol 1 | DJ Rob E | 0 |
| 116186 | TRZZZTN128EF35C42F | c56e00ecc890dfdfbdd551cb9ea15ca5 | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | Funky Dance Music Vol 1 | DJ Rob E | 0 |
| 116187 | TRZZZYV128F92E996D | 1b966417a9aa703873c5fa1cfe18da32 | data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... | Dear Lie | TLC | 1999 |
| 116188 | TRZZZYV128F92E996D | 3bcd7e0cc20adcc8dc3e912623bb0e1b | data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... | Dear Lie | TLC | 1999 |
116189 rows × 6 columns
In [26]:
# Keep exactly one MIDI file per MSD track (the first row of each group),
grouped_dataset = dataset.groupby('track_id').first().reset_index()
grouped_dataset = grouped_dataset[['track_id', 'midi_id', 'midi_filepath']]
# then re-attach the deduplicated per-track metadata columns.
# NOTE(review): if any track_id carried more than one distinct
# (title, artist, year) combination this left-merge would duplicate rows;
# the 31,034-row output below matches the unique track count, so that did
# not happen on this data.
grouped_dataset = grouped_dataset.merge(
    dataset[
        ['track_id', "title", "artist", "year"]
    ].drop_duplicates(), on='track_id', how='left' )
grouped_dataset = grouped_dataset[["track_id", "midi_id", "midi_filepath",
                                   "title", "artist", "year"]]
grouped_dataset
Out[26]:
| track_id | midi_id | midi_filepath | title | artist | year | |
|---|---|---|---|---|---|---|
| 0 | TRAAAGR128F425B14B | 1d9d16a9da90c090809c153754823c2b | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 |
| 1 | TRAAAZF12903CCCF6B | 05f21994c71a5f881e64f45c8d706165 | data/LMD/lmd_matched/A/A/A/TRAAAZF12903CCCF6B/... | Break My Stride | Matthew Wilder | 1983 |
| 2 | TRAABVM128F92CA9DC | 0dd4d2b9fbcf96a0fa363a1918255e58 | data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/... | Caught In A Dream | Tesla | 2004 |
| 3 | TRAABXH128F42955D6 | 01ffb8729a2465bfa7f9ba0288c89e24 | data/LMD/lmd_matched/A/A/B/TRAABXH128F42955D6/... | Keep An Eye On Summer (Album Version) | Brian Wilson | 1998 |
| 4 | TRAACQE12903CC706C | 1ee7c9ad5f18b2659789d9608c951ca5 | data/LMD/lmd_matched/A/A/C/TRAACQE12903CC706C/... | Summer | Old Man River | 2007 |
| ... | ... | ... | ... | ... | ... | ... |
| 31029 | TRZZYLO12903CAC06C | 128551e12d6dec38ad7ce00665c77fe5 | data/LMD/lmd_matched/Z/Z/Y/TRZZYLO12903CAC06C/... | I've Never Seen The Righteous Forsaken | Dallas Holm | 0 |
| 31030 | TRZZYTX128F92EBE33 | 538838021299e65875a8bec61a87a368 | data/LMD/lmd_matched/Z/Z/Y/TRZZYTX128F92EBE33/... | I Don't Want To Do It (2009 Digital Remaster) | George Harrison | 0 |
| 31031 | TRZZZBU128F426811B | 0702ddab7728f7b0e51321d8a0366367 | data/LMD/lmd_matched/Z/Z/Z/TRZZZBU128F426811B/... | Dame Una Se size= | Los Iracundos | 0 |
| 31032 | TRZZZTN128EF35C42F | 165e156e5192569e41dc8390b80a1465 | data/LMD/lmd_matched/Z/Z/Z/TRZZZTN128EF35C42F/... | Funky Dance Music Vol 1 | DJ Rob E | 0 |
| 31033 | TRZZZYV128F92E996D | 1b966417a9aa703873c5fa1cfe18da32 | data/LMD/lmd_matched/Z/Z/Z/TRZZZYV128F92E996D/... | Dear Lie | TLC | 1999 |
31034 rows × 6 columns
Adding the genre tags¶
In [27]:
# Inner-join the genre tags: tracks without a tagtraum entry are dropped
# (row count falls from 116,189 MIDI rows to 21,353 — see output below).
dataset = dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
dataset = dataset.drop(columns=["MSD_name"])  # redundant copy of track_id
dataset
Out[27]:
| track_id | midi_id | midi_filepath | title | artist | year | genre | |
|---|---|---|---|---|---|---|---|
| 0 | TRAAAGR128F425B14B | 1d9d16a9da90c090809c153754823c2b | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 | Pop |
| 1 | TRAAAGR128F425B14B | 5dd29e99ed7bd3cc0c5177a6e9de22ea | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 | Pop |
| 2 | TRAAAGR128F425B14B | b97c529ab9ef783a849b896816001748 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 | Pop |
| 3 | TRAAAGR128F425B14B | dac3cdd0db6341d8dc14641e44ed0d44 | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 | Pop |
| 4 | TRAABVM128F92CA9DC | 0dd4d2b9fbcf96a0fa363a1918255e58 | data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/... | Caught In A Dream | Tesla | 2004 | Rock |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 21348 | TRZZROL12903CAC4A8 | 0f0aaf2f90bc66da732f4371e703eae4 | data/LMD/lmd_matched/Z/Z/R/TRZZROL12903CAC4A8/... | Love Love | Amy MacDonald | 2010 | Pop |
| 21349 | TRZZSML12903CBB7BD | bc4aae694e7c433a6da16284e52e11be | data/LMD/lmd_matched/Z/Z/S/TRZZSML12903CBB7BD/... | Airwave (Radio Edit) | Rank 1 | 2000 | Electronic |
| 21350 | TRZZTHP128F427F139 | b085f5c3571f570bdc44fa0c9b6a0672 | data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/... | Briaris | The Sweetest Ache | 1992 | Rock |
| 21351 | TRZZTHP128F427F139 | f10a54a5e8b4d169eec5231bb6b15c94 | data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/... | Briaris | The Sweetest Ache | 1992 | Rock |
| 21352 | TRZZXJE12903CD1D93 | 7723a2ff572a0b49f9d0e552313f7db7 | data/LMD/lmd_matched/Z/Z/X/TRZZXJE12903CD1D93/... | Warm and Tender Love | Percy Sledge | 1967 | RnB |
21353 rows × 7 columns
In [28]:
# Same inner genre join for the one-MIDI-per-track table (31,034 → 6,180 rows).
grouped_dataset = grouped_dataset.merge(tagtraum_df, how="inner", left_on="track_id", right_on="MSD_name")
grouped_dataset = grouped_dataset.drop(columns=["MSD_name"])  # duplicate of track_id
grouped_dataset
Out[28]:
| track_id | midi_id | midi_filepath | title | artist | year | genre | |
|---|---|---|---|---|---|---|---|
| 0 | TRAAAGR128F425B14B | 1d9d16a9da90c090809c153754823c2b | data/LMD/lmd_matched/A/A/A/TRAAAGR128F425B14B/... | Into The Nightlife | Cyndi Lauper | 2008 | Pop |
| 1 | TRAABVM128F92CA9DC | 0dd4d2b9fbcf96a0fa363a1918255e58 | data/LMD/lmd_matched/A/A/B/TRAABVM128F92CA9DC/... | Caught In A Dream | Tesla | 2004 | Rock |
| 2 | TRAAGMC128F4292D0F | 0644195d1a3d14e0a0bd3d8b30dc68da | data/LMD/lmd_matched/A/A/G/TRAAGMC128F4292D0F/... | My Love (Album Version) | LITTLE TEXAS | 0 | Country |
| 3 | TRAANZE128F148BF55 | 0597bf18743a5aacfedc981eb58c9da9 | data/LMD/lmd_matched/A/A/N/TRAANZE128F148BF55/... | The Name Of The Game | Abba | 1977 | Pop |
| 4 | TRAAPPQ128F14961F5 | d39a20f33af4fb6b307529db8cf0cc3f | data/LMD/lmd_matched/A/A/P/TRAAPPQ128F14961F5/... | Wig | The B-52's | 1986 | Rock |
| ... | ... | ... | ... | ... | ... | ... | ... |
| 6175 | TRZZQGM128F9311E60 | 34d27fedd8dca07e36f50d69ba477e5b | data/LMD/lmd_matched/Z/Z/Q/TRZZQGM128F9311E60/... | Sun Of Jamaica | Goombay Dance Band | 1991 | Pop |
| 6176 | TRZZROL12903CAC4A8 | 0f0aaf2f90bc66da732f4371e703eae4 | data/LMD/lmd_matched/Z/Z/R/TRZZROL12903CAC4A8/... | Love Love | Amy MacDonald | 2010 | Pop |
| 6177 | TRZZSML12903CBB7BD | bc4aae694e7c433a6da16284e52e11be | data/LMD/lmd_matched/Z/Z/S/TRZZSML12903CBB7BD/... | Airwave (Radio Edit) | Rank 1 | 2000 | Electronic |
| 6178 | TRZZTHP128F427F139 | b085f5c3571f570bdc44fa0c9b6a0672 | data/LMD/lmd_matched/Z/Z/T/TRZZTHP128F427F139/... | Briaris | The Sweetest Ache | 1992 | Rock |
| 6179 | TRZZXJE12903CD1D93 | 7723a2ff572a0b49f9d0e552313f7db7 | data/LMD/lmd_matched/Z/Z/X/TRZZXJE12903CD1D93/... | Warm and Tender Love | Percy Sledge | 1967 | RnB |
6180 rows × 7 columns
Sluggifying our parameters¶
In [29]:
genres = dataset["genre"].unique()
artists = dataset["artist"].unique()
years = dataset["year"].unique()
In [30]:
def slug(text: str) -> str:
    """Return an ALL_CAPS, ASCII, alnum/underscore version of `text`.

    Steps: strip accents to plain ASCII, turn every run of non-word
    characters into a single underscore, trim edge underscores, upper-case.
    """
    ascii_text = (
        unicodedata.normalize("NFKD", text).encode("ascii", "ignore").decode()
    )
    underscored = re.sub(r"[^\w]+", "_", ascii_text)
    collapsed = re.sub(r"_+", "_", underscored)
    return collapsed.strip("_").upper()
In [31]:
# Slugified versions used to build vocabulary tokens like <ARTIST_ABBA>.
genres_slugged = np.array([slug(genre) for genre in genres])
artists_slugged = np.array([slug(artist) for artist in artists])
# Year 0 appears to mean "unknown" in MSD (see earlier outputs); the isna()
# filter drops genuinely missing entries before the int cast.
years = np.array([int(year) for year in years if not pd.isna(year)])
In [32]:
# Lookup tables: original value -> slug, one frame per parameter.
# NOTE(review): `genres`/`artists`/`years` are rebound from numpy arrays to
# DataFrames here — the same names now hold a different kind of object,
# which makes out-of-order re-runs of these cells fragile.
genres = pd.DataFrame({
    'genre': genres,
    'slugged_genre': genres_slugged
})
artists = pd.DataFrame({
    'artist': artists,
    'slugged_artist': artists_slugged
})
years = pd.DataFrame({
    'year': years
})
In [33]:
genres = genres.sort_values(by='genre')
artists = artists.sort_values(by='artist')
years = years.sort_values(by='year')
In [34]:
# Attach the slug columns by mapping each row through the lookup tables.
dataset["slugged_genre"] = dataset["genre"].map(genres.set_index('genre')['slugged_genre'])
dataset["slugged_artist"] = dataset["artist"].map(artists.set_index('artist')['slugged_artist'])
grouped_dataset["slugged_genre"] = grouped_dataset["genre"].map(genres.set_index('genre')['slugged_genre'])
grouped_dataset["slugged_artist"] = grouped_dataset["artist"].map(artists.set_index('artist')['slugged_artist'])
Saving our data¶
Saving the metadata datasets¶
In [35]:
dataset.to_csv("data/metadata.csv", index=False)
In [36]:
grouped_dataset.to_csv("data/grouped_metadata.csv", index=False)
Saving the different parameters to csvs¶
In [37]:
genres.to_csv("data/genres.csv", index=False)
artists.to_csv("data/artists.csv", index=False)
years.to_csv("data/years.csv", index=False)
2. Model Implementation¶
In [85]:
# Reload everything saved by the preprocessing section so this section can
# run from a fresh kernel.
dataset = pd.read_csv("data/metadata.csv")
grouped_dataset = pd.read_csv("data/grouped_metadata.csv")
genres = pd.read_csv("data/genres.csv")
# NOTE(review): data/titles.csv is never written in the preprocessing section
# above (only metadata/grouped_metadata/genres/artists/years are saved), and
# `titles` is not used by the visible cells below — confirm where this file
# comes from or drop the line.
titles = pd.read_csv("data/titles.csv")
artists = pd.read_csv("data/artists.csv")
years = pd.read_csv("data/years.csv")
In [86]:
genres_slugged = genres["slugged_genre"].values
artists_slugged = artists["slugged_artist"].values
years_vals = years["year"].values
In [87]:
# Config with which the model was trained
# MAX_TOKENS = 512
# BATCH_SIZE = 2
# D_MODEL = 512
# N_LAYERS = 6
# N_HEADS = 8
# New config to try
MAX_TOKENS = 1024   # fixed sequence length fed to the model (crop/pad target)
BATCH_SIZE = 8
D_MODEL = 768       # embedding / model width
N_LAYERS = 8        # number of decoder blocks
N_HEADS = 12  # 768 / 12 = 64 per head
Tokenization¶
Defining the tokenizer¶
In [88]:
# REMI tokenizer configuration (miditok).
config = TokenizerConfig(
    num_velocities=32,               # velocity quantization bins
    use_chords=True,
    use_programs=True,               # keep instrument/program tokens
    beat_res={(0,4): 8, (4,8): 4},   # finer time grid for beats 0-4
    use_rests=True,
    rest_range=(2,8),                # NOTE(review): confirm this kwarg is
                                     # still accepted by the installed
                                     # miditok version
    use_time_signatures=True
)
tokenizer = REMI(config)
Adding our special tokens¶
In [89]:
# Register one conditioning token per genre/artist/year, plus <EOS>/<PAD>.
special_toks = (
    [f"<GENRE_{g}>" for g in genres_slugged]
    + [f"<ARTIST_{a}>" for a in artists_slugged]
    + [f"<YEAR_{y}>" for y in years_vals]
    + ["<EOS>", "<PAD>"]
)
for tok in special_toks:
    tokenizer.add_to_vocab(tok)
Tokenizing: Storing each track as a numpy int32 array.¶
In [90]:
tokenizing = False
In [91]:
# ─── 1. Helpers ──────────────────────────────────────────────────────────
def build_prefix(genre, artist, year, tokenizer):
    """Map a metadata row to its three conditioning token ids.

    Returns [genre_id, artist_id, year_id], looked up in tokenizer.vocab.
    Raises KeyError if any of the special tokens was not registered.
    """
    vocab = tokenizer.vocab
    return [
        vocab[f"<GENRE_{genre}>"],
        vocab[f"<ARTIST_{artist}>"],
        vocab[f"<YEAR_{year}>"],
    ]
# ─── 3. Output directory -------------------------------------------------
# NOTE(review): assumes out_dir already exists — there is no makedirs call.
out_dir = "data/tokens/train"
# ─── 4. Iterate files ----------------------------------------------------
if tokenizing:
    rows, _ = grouped_dataset.shape
    for row in tqdm(range(rows)):
        try:
            # 4.0. Get row
            # NOTE(review): `row` is rebound from the integer index to the
            # pandas Series here — it works, but is confusing to read.
            row = grouped_dataset.iloc[row]
            # 4.1. Get MIDI filepath
            midi_path = row["midi_filepath"]
            # 4.2. Get the track ID
            track_id = row["track_id"]
            # 4a. Build CONDITIONING prefix
            genre = row["slugged_genre"]
            artist = row["slugged_artist"]
            year = row["year"]
            prefix_ids = build_prefix(genre, artist, year, tokenizer)  # list[int]
            # 4b. Encode MIDI to tokens (a TokSequence — its .ids are used below)
            midi = Score(midi_path)
            midi_tokens = tokenizer(midi)
            # 4c. Concatenate prefix + midi + <EOS>
            seq_ids = prefix_ids + midi_tokens.ids + [tokenizer.vocab["<EOS>"]]
            # 4d. Save as int32 .npy, one file per track
            np.save(f"{out_dir}/{track_id}.npy", np.array(seq_ids, dtype=np.int32))
        except Exception as e:
            # Corrupt/unreadable MIDI files are skipped, but logged with a
            # full traceback so failures stay visible.
            print(f"Error processing {midi_path}: {e}")
            traceback.print_exc()
            continue
The Model¶
In [92]:
class RelativePositionalEncoding(nn.Module):
    """
    Sinusoidal positional encoding added to the token embeddings.

    Despite the class name, this is the standard *absolute* sinusoid table
    from "Attention Is All You Need": a fixed (max_len, d_model) buffer with
    sin in even columns and cos in odd columns.  `forward` returns a
    (1, seq_len, d_model) slice that broadcasts over the batch, so callers
    simply do `x + pos(x)`.

    Args
    ----
    d_model : int            # embedding size
    max_len : int, optional  # maximum supported sequence length
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Build the sinusoid table once up front.
        positions = torch.arange(max_len, dtype=torch.float).unsqueeze(1)  # (L, 1)
        inv_freq = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float)
            * -(math.log(10000.0) / d_model)
        )                                                                  # (D/2,)
        table = torch.zeros(max_len, d_model)                              # (L, D)
        table[:, 0::2] = torch.sin(positions * inv_freq)
        table[:, 1::2] = torch.cos(positions * inv_freq)
        # Buffer (not a Parameter): moves with .to(device), never trained.
        self.register_buffer("pe", table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor with the sequence on dim 1; only x.size(1) is read, so
            both (batch, seq_len) id tensors and (batch, seq_len, d_model)
            embeddings are accepted.

        Returns
        -------
        Tensor of shape (1, seq_len, d_model), broadcastable against the
        embedded input.

        Raises
        ------
        ValueError : if seq_len exceeds max_len.
        """
        seq_len = x.size(1)
        if seq_len > self.max_len:
            raise ValueError(f"Sequence length {seq_len} exceeds max_len {self.max_len}")
        return self.pe[:seq_len].unsqueeze(0)
In [93]:
class TransformerDecoderBlock(nn.Module):
    """
    Causal self-attention + feed-forward decoder block.

    Causal and padding constraints are merged into one additive float
    attention mask of shape (B*H, L, L), so nn.MultiheadAttention never has
    to combine a bool key_padding_mask with a float attn_mask itself.

    Parameters
    ----------
    d_model : int            # model width
    n_heads : int            # attention heads (must divide d_model)
    max_len : int, optional  # size of the precomputed causal mask
    dropout : float, optional
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        max_len: int = 2048,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=n_heads,
            dropout=dropout,
            batch_first=True,
        )
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            nn.GELU(),
            nn.Linear(4 * d_model, d_model),
        )
        self.dropout = nn.Dropout(dropout)
        # Precompute the float causal mask: 0 on/under the diagonal, -inf above.
        causal = torch.triu(
            torch.full((max_len, max_len), float("-inf")),
            diagonal=1,
        )
        self.register_buffer("causal_mask", causal, persistent=False)

    def forward(
        self,
        x: torch.Tensor,               # (B, L, D)
        pad_mask: torch.Tensor = None  # (B, L), True = real token, False = pad
    ) -> torch.Tensor:
        """Apply masked self-attention then the feed-forward sublayer.

        Returns a tensor of the same shape as `x`.
        """
        B, L, _ = x.shape
        H = self.self_attn.num_heads
        device = x.device
        dtype = x.dtype
        # 1) slice the (L, L) causal part
        causal = self.causal_mask[:L, :L]
        if pad_mask is not None:
            # 2) additive pad mask over the *key* dimension.
            # FIX: use the most-negative finite value instead of -inf here.
            # A query position that is itself padding has every key either
            # padded or causally blocked; with -inf its whole score row is
            # -inf and softmax returns NaN, which then poisons every later
            # layer (NaN values leak through even zero attention weights).
            # With a finite fill, degenerate rows just get a uniform
            # distribution — harmless, since those positions are dropped by
            # ignore_index in the loss.  exp(finfo.min - max) underflows to
            # exactly 0, so non-degenerate rows are numerically unchanged.
            neg = torch.finfo(dtype).min
            pad_float = torch.zeros((B, L), device=device, dtype=dtype)
            pad_float = pad_float.masked_fill(~pad_mask, neg)
            # 3) broadcast to (B, L, L): unsqueeze(1) spreads over queries
            attn_batch = causal.unsqueeze(0) + pad_float.unsqueeze(1)
            # 4) repeat per head → (B*H, L, L), the 3-D layout MHA expects
            attn_mask = attn_batch.repeat_interleave(H, dim=0)
        else:
            attn_mask = causal  # 2-D mask broadcasts over batch and heads
        # 5) self-attention with ONLY the merged attn_mask
        attn_out, _ = self.self_attn(
            x, x, x,
            attn_mask=attn_mask,
        )
        # 6) post-norm residual connections around attention and FFN
        x = self.ln1(x + self.dropout(attn_out))
        x = self.ln2(x + self.dropout(self.ff(x)))
        return x
In [94]:
class Synphony(nn.Module):
    """Decoder-only transformer language model over the MIDI token vocab.

    Token embedding + sinusoidal positions → n_layers decoder blocks →
    final LayerNorm → linear projection back to vocabulary logits.
    """

    def __init__(self, vocab_size, d_model=512, n_layers=6, n_heads=8):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos = RelativePositionalEncoding(d_model, max_len=2048)
        self.blocks = nn.ModuleList(
            TransformerDecoderBlock(d_model, n_heads) for _ in range(n_layers)
        )
        self.ln = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """x: (B, L) token ids; pad_mask: (B, L) bool, True = real token.

        Returns (B, L, vocab_size) next-token logits.
        """
        hidden = self.embed(x) + self.pos(x)
        for block in self.blocks:
            hidden = block(hidden, pad_mask)
        return self.out(self.ln(hidden))
The Training Loop¶
In [95]:
from torch.utils.data import Dataset, DataLoader
import random
# For reproducibility of the train/test shuffle below.
# NOTE(review): this seeds only Python's `random`; the torch.randint crop in
# collate_fn is not seeded, so crops still vary between runs — confirm
# whether torch.manual_seed should be set too.
random.seed(42)
In [96]:
# Collect every tokenized song (.npy) written by the tokenization pass.
tok_paths = []
for root, _dirs, files in os.walk('data/tokens/train'):
    for name in files:
        candidate = os.path.join(root, name)
        if candidate.endswith('.npy'):
            tok_paths.append(candidate)
In [97]:
len(tok_paths)
Out[97]:
6150
In [98]:
# Random song-level split; shuffle is in-place and seeded above.
split_index = int(len(tok_paths) * 0.8)  # 80% train, 20% test
random.shuffle(tok_paths)
train_paths = tok_paths[:split_index]
test_paths = tok_paths[split_index:]
In [104]:
# ─── 1. Dataset + collate ────────────────────────────────────────────────
class MidiTokenDataset(Dataset):
    """One item per song: a 1-D int64 array of token ids loaded from .npy."""

    def __init__(self, npy_paths):
        self.paths = npy_paths

    def __len__(self):
        # Number of songs in this split.
        return len(self.paths)

    def __getitem__(self, idx):
        # Loaded lazily from disk; cast int32 → int64 for embedding lookup.
        return np.load(self.paths[idx]).astype(np.int64)
def collate_fn(batch, pad_id, max_len=None):
    """Pad/crop a batch of 1-D numpy int sequences to one fixed length.

    Parameters
    ----------
    batch : list of 1-D np.ndarray
        Token-id sequences (one song each).
    pad_id : int
        Fill value for positions past the end of a sequence.  NOTE: the pad
        mask is derived from token *values*, so pad_id must never occur as a
        real token inside the sequences.
    max_len : int, optional
        Target length.  Defaults to the notebook-level MAX_TOKENS constant,
        which keeps the original call sites working unchanged; passing it
        explicitly removes the hidden dependency on that global.

    Returns
    -------
    x : LongTensor (B, max_len)       padded/cropped token ids
    pad_mask : BoolTensor (B, max_len) True where x holds a real token
    """
    L = MAX_TOKENS if max_len is None else max_len
    B = len(batch)
    x = torch.full((B, L), pad_id, dtype=torch.long)
    for i, seq in enumerate(batch):
        seq = torch.from_numpy(seq)
        if seq.numel() > L:
            # Random crop for long songs.  NOTE(review): this can cut off the
            # genre/artist/year conditioning prefix that sits at the start of
            # every saved sequence — confirm that is intended.
            start = torch.randint(0, seq.numel() - L + 1, (1,)).item()
            seq = seq[start : start + L]
        x[i, : seq.numel()] = seq
    pad_mask = ~x.eq(pad_id)
    return x, pad_mask
# ─── 2. DataLoaders ──────────────────────────────────────────────────────
PAD_ID = tokenizer.vocab['<PAD>']  # id used for padding AND loss ignore_index
train_ds = MidiTokenDataset(train_paths)
val_ds = MidiTokenDataset(test_paths)
train_loader = DataLoader(
    train_ds, batch_size=BATCH_SIZE, shuffle=True,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)
val_loader = DataLoader(
    val_ds, batch_size=BATCH_SIZE, shuffle=False,
    collate_fn=lambda b: collate_fn(b, PAD_ID)
)
# ─── 3. Model, optimiser, scheduler ─────────────────────────────────────
# Prefer CUDA, then Apple MPS, else CPU.
device = (
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
model = Synphony(
    vocab_size=len(tokenizer), d_model=D_MODEL,
    n_layers=N_LAYERS, n_heads=N_HEADS).to(device)
# AdamW with a small weight decay for regularization.
optim = torch.optim.AdamW(model.parameters(),
                          lr=3e-4,
                          weight_decay=1e-2)
# Halve the LR after 2 epochs without validation-loss improvement.
# NOTE(review): `verbose` is deprecated/removed in recent PyTorch releases —
# confirm against the installed version.
sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optim,
                                                   mode='min',    # val loss should go down
                                                   factor=0.5,    # cut LR in half
                                                   patience=2,    # wait 2 epochs
                                                   min_lr=1e-6,   # floor on LR
                                                   verbose=True)
# ─── 4. Training loop ────────────────────────────────────────────────────
best_val_loss = float("inf")
for epoch in tqdm(range(1, 51)):  # 50 epochs
    # ---- train ----------------------------------------------------------
    model.train()
    running_loss = 0.0
    for x, pad_mask in train_loader:  # pad_mask: (B, L)
        x, pad_mask = x.to(device), pad_mask.to(device)
        # Teacher forcing: predict token t+1 from tokens ≤ t.
        logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])
        loss = F.cross_entropy(
            logits.reshape(-1, logits.size(-1)),
            x[:, 1:].reshape(-1),
            ignore_index=PAD_ID,      # padded targets contribute no loss
            label_smoothing=0.1
        )
        optim.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optim.step()
        running_loss += loss.item()
    # NOTE: train loss includes the label-smoothing penalty, so train PPL is
    # not directly comparable to the (unsmoothed) validation PPL below.
    train_ppl = math.exp(running_loss / len(train_loader))
    # ---- validation -----------------------------------------------------
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for x, pad_mask in val_loader:  # pad_mask is (B, L)
            x, pad_mask = x.to(device), pad_mask.to(device)
            # Same shifted setup as training, but no label smoothing.
            logits = model(x[:, :-1], pad_mask=pad_mask[:, :-1])
            val_loss += F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                x[:, 1:].reshape(-1),
                ignore_index=PAD_ID
            ).item()
    val_ppl = math.exp(val_loss / len(val_loader))
    print(f"val PPL {val_ppl:6.2f}")
    print(f"Epoch {epoch:02d} ▸ train PPL {train_ppl:6.2f} | val PPL {val_ppl:6.2f}")
    # ---- scheduler step -----------------------------------------------
    sched.step(val_loss / len(val_loader))  # step on mean val loss
    # log current LR
    current_lr = optim.param_groups[0]['lr']
    print(f" ↳ LR now = {current_lr:.2e}")
    # ---- checkpoint -----------------------------------------------------
    # Compares the *summed* val loss; consistent across epochs because the
    # validation set (and so the number of batches) is fixed.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), "synphony_best.pt")
        print(" ✓ new best model saved")
print("Done!")
0%| | 0/50 [00:00<?, ?it/s]
val PPL 12.04
Epoch 01 ▸ train PPL 41.38 | val PPL 12.04
↳ LR now = 3.00e-04
2%|▏ | 1/50 [08:14<6:43:27, 494.03s/it]
✓ new best model saved
val PPL 6.79
Epoch 02 ▸ train PPL 21.59 | val PPL 6.79
↳ LR now = 3.00e-04
4%|▍ | 2/50 [16:31<6:36:43, 495.90s/it]
✓ new best model saved
val PPL 4.01
Epoch 03 ▸ train PPL 12.46 | val PPL 4.01
↳ LR now = 3.00e-04
6%|▌ | 3/50 [24:48<6:28:57, 496.54s/it]
✓ new best model saved
val PPL 3.48
Epoch 04 ▸ train PPL 9.59 | val PPL 3.48
↳ LR now = 3.00e-04
8%|▊ | 4/50 [33:05<6:20:55, 496.85s/it]
✓ new best model saved
val PPL 3.28
Epoch 05 ▸ train PPL 8.84 | val PPL 3.28
↳ LR now = 3.00e-04
10%|█ | 5/50 [41:23<6:12:51, 497.16s/it]
✓ new best model saved
val PPL 3.17
Epoch 06 ▸ train PPL 8.44 | val PPL 3.17
↳ LR now = 3.00e-04
12%|█▏ | 6/50 [49:40<6:04:39, 497.25s/it]
✓ new best model saved
val PPL 3.06
Epoch 07 ▸ train PPL 8.14 | val PPL 3.06
↳ LR now = 3.00e-04
14%|█▍ | 7/50 [57:58<5:56:26, 497.36s/it]
✓ new best model saved
val PPL 3.03
Epoch 08 ▸ train PPL 7.97 | val PPL 3.03
↳ LR now = 3.00e-04
16%|█▌ | 8/50 [1:06:15<5:48:09, 497.36s/it]
✓ new best model saved
val PPL 2.97
Epoch 09 ▸ train PPL 7.81 | val PPL 2.97
↳ LR now = 3.00e-04
18%|█▊ | 9/50 [1:14:33<5:39:53, 497.39s/it]
✓ new best model saved
val PPL 2.92
Epoch 10 ▸ train PPL 7.64 | val PPL 2.92
↳ LR now = 3.00e-04
20%|██ | 10/50 [1:22:50<5:31:34, 497.37s/it]
✓ new best model saved
val PPL 2.87
Epoch 11 ▸ train PPL 7.55 | val PPL 2.87
↳ LR now = 3.00e-04
22%|██▏ | 11/50 [1:31:08<5:23:17, 497.38s/it]
✓ new best model saved
24%|██▍ | 12/50 [1:39:24<5:14:51, 497.14s/it]
val PPL 2.87
Epoch 12 ▸ train PPL 7.42 | val PPL 2.87
↳ LR now = 3.00e-04
val PPL 2.82
Epoch 13 ▸ train PPL 7.35 | val PPL 2.82
↳ LR now = 3.00e-04
26%|██▌ | 13/50 [1:47:42<5:06:43, 497.40s/it]
✓ new best model saved
val PPL 2.80
Epoch 14 ▸ train PPL 7.28 | val PPL 2.80
↳ LR now = 3.00e-04
28%|██▊ | 14/50 [1:56:00<4:58:31, 497.54s/it]
✓ new best model saved
val PPL 2.77
Epoch 15 ▸ train PPL 7.23 | val PPL 2.77
↳ LR now = 3.00e-04
30%|███ | 15/50 [2:04:18<4:50:18, 497.66s/it]
✓ new best model saved
val PPL 2.76
Epoch 16 ▸ train PPL 7.15 | val PPL 2.76
↳ LR now = 3.00e-04
32%|███▏ | 16/50 [2:12:36<4:42:01, 497.70s/it]
✓ new best model saved
val PPL 2.74
Epoch 17 ▸ train PPL 7.10 | val PPL 2.74
↳ LR now = 3.00e-04
34%|███▍ | 17/50 [2:20:54<4:33:44, 497.73s/it]
✓ new best model saved
val PPL 2.71
Epoch 18 ▸ train PPL 7.05 | val PPL 2.71
↳ LR now = 3.00e-04
36%|███▌ | 18/50 [2:29:11<4:25:23, 497.62s/it]
✓ new best model saved
val PPL 2.67
Epoch 19 ▸ train PPL 7.00 | val PPL 2.67
↳ LR now = 3.00e-04
38%|███▊ | 19/50 [2:37:29<4:17:08, 497.69s/it]
✓ new best model saved
40%|████ | 20/50 [2:45:45<4:08:35, 497.17s/it]
val PPL 2.70
Epoch 20 ▸ train PPL 6.96 | val PPL 2.70
↳ LR now = 3.00e-04
42%|████▏ | 21/50 [2:54:01<4:00:11, 496.95s/it]
val PPL 2.68
Epoch 21 ▸ train PPL 6.92 | val PPL 2.68
↳ LR now = 3.00e-04
val PPL 2.66
Epoch 22 ▸ train PPL 6.88 | val PPL 2.66
↳ LR now = 3.00e-04
44%|████▍ | 22/50 [3:02:19<3:52:00, 497.15s/it]
✓ new best model saved
46%|████▌ | 23/50 [3:10:35<3:43:35, 496.88s/it]
val PPL 2.66
Epoch 23 ▸ train PPL 6.82 | val PPL 2.66
↳ LR now = 3.00e-04
val PPL 2.63
Epoch 24 ▸ train PPL 6.80 | val PPL 2.63
↳ LR now = 3.00e-04
48%|████▊ | 24/50 [3:18:53<3:35:24, 497.11s/it]
✓ new best model saved
50%|█████ | 25/50 [3:27:09<3:26:59, 496.79s/it]
val PPL 2.65
Epoch 25 ▸ train PPL 6.79 | val PPL 2.65
↳ LR now = 3.00e-04
val PPL 2.62
Epoch 26 ▸ train PPL 6.75 | val PPL 2.62
↳ LR now = 3.00e-04
52%|█████▏ | 26/50 [3:35:26<3:18:49, 497.05s/it]
✓ new best model saved
54%|█████▍ | 27/50 [3:43:42<3:10:24, 496.72s/it]
val PPL 2.63
Epoch 27 ▸ train PPL 6.73 | val PPL 2.63
↳ LR now = 3.00e-04
val PPL 2.61
Epoch 28 ▸ train PPL 6.68 | val PPL 2.61
↳ LR now = 3.00e-04
56%|█████▌ | 28/50 [3:52:00<3:02:12, 496.93s/it]
✓ new best model saved
58%|█████▊ | 29/50 [4:00:16<2:53:51, 496.73s/it]
val PPL 2.61
Epoch 29 ▸ train PPL 6.67 | val PPL 2.61
↳ LR now = 3.00e-04
val PPL 2.60
Epoch 30 ▸ train PPL 6.64 | val PPL 2.60
↳ LR now = 3.00e-04
60%|██████ | 30/50 [4:08:34<2:45:42, 497.11s/it]
✓ new best model saved
val PPL 2.57
Epoch 31 ▸ train PPL 6.65 | val PPL 2.57
↳ LR now = 3.00e-04
62%|██████▏ | 31/50 [4:16:52<2:37:28, 497.31s/it]
✓ new best model saved
64%|██████▍ | 32/50 [4:25:08<2:29:05, 496.96s/it]
val PPL 2.59
Epoch 32 ▸ train PPL 6.59 | val PPL 2.59
↳ LR now = 3.00e-04
66%|██████▌ | 33/50 [4:33:24<2:20:44, 496.72s/it]
val PPL 2.58
Epoch 33 ▸ train PPL 6.57 | val PPL 2.58
↳ LR now = 3.00e-04
val PPL 2.55
Epoch 34 ▸ train PPL 6.56 | val PPL 2.55
↳ LR now = 3.00e-04
68%|██████▊ | 34/50 [4:41:42<2:12:32, 497.02s/it]
✓ new best model saved
70%|███████ | 35/50 [4:49:58<2:04:11, 496.74s/it]
val PPL 2.57
Epoch 35 ▸ train PPL 6.53 | val PPL 2.57
↳ LR now = 3.00e-04
72%|███████▏ | 36/50 [4:58:14<1:55:51, 496.52s/it]
val PPL 2.57
Epoch 36 ▸ train PPL 6.53 | val PPL 2.57
↳ LR now = 3.00e-04
val PPL 2.55
Epoch 37 ▸ train PPL 6.53 | val PPL 2.55
↳ LR now = 3.00e-04
74%|███████▍ | 37/50 [5:06:31<1:47:38, 496.79s/it]
✓ new best model saved
val PPL 2.54
Epoch 38 ▸ train PPL 6.49 | val PPL 2.54
↳ LR now = 3.00e-04
76%|███████▌ | 38/50 [5:14:49<1:39:24, 497.00s/it]
✓ new best model saved
78%|███████▊ | 39/50 [5:23:05<1:31:03, 496.68s/it]
val PPL 2.55
Epoch 39 ▸ train PPL 6.45 | val PPL 2.55
↳ LR now = 3.00e-04
80%|████████ | 40/50 [5:31:21<1:22:44, 496.46s/it]
val PPL 2.55
Epoch 40 ▸ train PPL 6.45 | val PPL 2.55
↳ LR now = 3.00e-04
val PPL 2.54
Epoch 41 ▸ train PPL 6.44 | val PPL 2.54
↳ LR now = 3.00e-04
82%|████████▏ | 41/50 [5:39:38<1:14:30, 496.69s/it]
✓ new best model saved
val PPL 2.52
Epoch 42 ▸ train PPL 6.40 | val PPL 2.52
↳ LR now = 3.00e-04
84%|████████▍ | 42/50 [5:47:55<1:06:15, 496.91s/it]
✓ new best model saved
86%|████████▌ | 43/50 [5:56:11<57:56, 496.60s/it]
val PPL 2.52
Epoch 43 ▸ train PPL 6.41 | val PPL 2.52
↳ LR now = 3.00e-04
88%|████████▊ | 44/50 [6:04:27<49:38, 496.37s/it]
val PPL 2.53
Epoch 44 ▸ train PPL 6.38 | val PPL 2.53
↳ LR now = 3.00e-04
val PPL 2.52
Epoch 45 ▸ train PPL 6.38 | val PPL 2.52
↳ LR now = 1.50e-04
90%|█████████ | 45/50 [6:12:45<41:23, 496.68s/it]
✓ new best model saved
val PPL 2.46
Epoch 46 ▸ train PPL 6.20 | val PPL 2.46
↳ LR now = 1.50e-04
92%|█████████▏| 46/50 [6:21:02<33:07, 496.86s/it]
✓ new best model saved
val PPL 2.45
Epoch 47 ▸ train PPL 6.12 | val PPL 2.45
↳ LR now = 1.50e-04
94%|█████████▍| 47/50 [6:29:19<24:50, 496.93s/it]
✓ new best model saved
96%|█████████▌| 48/50 [6:37:35<16:33, 496.73s/it]
val PPL 2.46
Epoch 48 ▸ train PPL 6.09 | val PPL 2.46
↳ LR now = 1.50e-04
98%|█████████▊| 49/50 [6:45:51<08:16, 496.51s/it]
val PPL 2.45
Epoch 49 ▸ train PPL 6.06 | val PPL 2.45
↳ LR now = 1.50e-04
val PPL 2.43
Epoch 50 ▸ train PPL 6.04 | val PPL 2.43
↳ LR now = 1.50e-04
100%|██████████| 50/50 [6:54:09<00:00, 496.99s/it]
✓ new best model saved Done!
In [72]:
tokenizer.vocab_size
Out[72]:
3534
3. Model Inference¶
In [105]:
model.eval()
Out[105]:
Synphony(
(embed): Embedding(3534, 768)
(pos): RelativePositionalEncoding()
(blocks): ModuleList(
(0-7): 8 x TransformerDecoderBlock(
(self_attn): MultiheadAttention(
(out_proj): NonDynamicallyQuantizableLinear(in_features=768, out_features=768, bias=True)
)
(ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(ff): Sequential(
(0): Linear(in_features=768, out_features=3072, bias=True)
(1): GELU(approximate='none')
(2): Linear(in_features=3072, out_features=768, bias=True)
)
(dropout): Dropout(p=0.1, inplace=False)
)
)
(ln): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
(out): Linear(in_features=768, out_features=3534, bias=True)
)
In [106]:
TEMPERATURE = 1.0   # softmax temperature: >1 flattens, <1 sharpens the distribution
TOP_K = 8           # sample only from the k most likely tokens
# ─── 2. Helper for top-k filtering ───────────────────────────────────────
def top_k_logits(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Mask all but the top-k logits with -inf so softmax ignores them.

    Filters along the last dimension, so both 1-D ``(V,)`` and batched
    ``(..., V)`` logit tensors are supported (the original ``v[-1]`` picked
    a single global threshold and was wrong for 2-D input). ``k`` is
    clamped to the vocabulary size — ``torch.topk`` raises if ``k > V`` —
    and ``k <= 0`` returns the logits unfiltered.
    """
    if k <= 0:
        return logits
    k = min(k, logits.size(-1))          # guard: topk(k > V) raises
    v, _ = torch.topk(logits, k, dim=-1)
    threshold = v[..., -1:]              # k-th largest value per row
    return torch.where(logits < threshold,
                       torch.full_like(logits, -float("Inf")),
                       logits)
# ─── 3. Autoregressive generation ────────────────────────────────────────
@torch.no_grad()
def generate(
    genre: str,
    artist: str,
    year: int,
    max_length: int = MAX_TOKENS
) -> list[int]:
    """Sample a token sequence conditioned on (genre, artist, year).

    Builds the 3-token metadata prefix via ``build_prefix``, then samples
    autoregressively with temperature + top-k filtering until ``<EOS>`` is
    drawn or ``max_length`` total tokens are reached. Returns the full ID
    list, metadata prefix included (callers strip it in ``tokens_to_midi``).
    """
    prefix = build_prefix(genre, artist, year, tokenizer)
    input_ids = torch.tensor([prefix], device=device)                # (1, P)
    pad_mask = torch.ones_like(input_ids, dtype=torch.bool, device=device)
    eos_id = tokenizer.vocab["<EOS>"]                                # hoisted out of the loop
    one = torch.ones((1, 1), dtype=torch.bool, device=device)        # reused mask extension
    for _ in tqdm(range(max_length - len(prefix))):
        logits = model(input_ids, pad_mask=pad_mask)
        next_logits = logits[0, -1, :] / TEMPERATURE                 # (V,)
        next_logits = top_k_logits(next_logits, TOP_K)
        probs = F.softmax(next_logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)            # (1,)
        if next_id.item() == eos_id:
            break
        # append token; extend pad_mask by one column instead of
        # rebuilding the whole (1, L) mask every step
        input_ids = torch.cat([input_ids, next_id.unsqueeze(0)], dim=1)  # (1, L+1)
        pad_mask = torch.cat([pad_mask, one], dim=1)
    return input_ids[0].tolist()
# ─── 4. Decode to MIDI & save ────────────────────────────────────────────
def tokens_to_midi(token_ids: list[int], out_path: str) -> None:
    """Strip the metadata prefix (and trailing <EOS>) and write a .mid file.

    Note: calling ``tokenizer(ids)`` with miditok decodes the IDs into a
    symusic ``Score`` — not a PrettyMIDI object as the old comment claimed —
    which is why the save call is ``dump_midi``.
    """
    # 1) drop the first 3 prefix IDs (genre, artist, year)
    musical_ids = token_ids[3:]
    # 2) drop trailing <EOS> if present
    eos_id = tokenizer.vocab["<EOS>"]
    if len(musical_ids) > 0 and musical_ids[-1] == eos_id:
        musical_ids = musical_ids[:-1]
    # 3) decode only the musical tokens back to a symusic Score
    score = tokenizer(musical_ids)
    # 4) write out the .mid file
    score.dump_midi(out_path)
In [120]:
# ─── 5. Run it! ───────────────────────────────────────────────────────────
# Conditioning metadata for the piece we want the model to generate.
genre_input, artist_input, year_input = "ROCK", "GLORIA_GAYNOR", 1990

# Sample up to 512 tokens, then decode (minus the metadata prefix) to disk.
gen_ids = generate(genre_input, artist_input, year_input, max_length=512)
out_file = "generated.mid"
tokens_to_midi(gen_ids, out_file)
print(f"🎹 Wrote MIDI to {out_file}")
100%|██████████| 509/509 [00:03<00:00, 131.18it/s]
🎹 Wrote MIDI to generated.mid
In [121]:
from midi2audio import FluidSynth
from IPython.display import Audio

# Render the MIDI to WAV. FluidSynth requires a SoundFont; midi2audio's
# default (~/.fluidsynth/default_sound_font.sf2) may be missing or invalid —
# the previous run logged "not a SoundFont or MIDI file" for exactly that
# path — so pass an explicit SoundFont when one is available.
SOUND_FONT = os.environ.get(
    "SOUND_FONT", os.path.expanduser("~/.fluidsynth/default_sound_font.sf2")
)
fs = FluidSynth(SOUND_FONT) if os.path.exists(SOUND_FONT) else FluidSynth()
fs.midi_to_audio('generated.mid', 'generated.wav')
# now embed the WAV inline
Audio('generated.wav')
Parameter '/home/jupyter/.fluidsynth/default_sound_font.sf2' not a SoundFont or MIDI file or error occurred identifying it.
FluidSynth runtime version 2.1.7 Copyright (C) 2000-2021 Peter Hanappe and others. Distributed under the LGPL license. SoundFont(R) is a registered trademark of E-mu Systems, Inc. Rendering audio to file 'generated.wav'..
Out[121]:
In [ ]: